/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/util/saved_tensor_slice_util.h"

namespace tensorflow {

using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;

namespace {

// Shape function for ops whose inputs and outputs are all scalars.
//
// Returns an error if any input cannot be constrained to rank 0; otherwise
// sets every output shape to a scalar and returns OK.
Status ScalarInputsAndOutputs(InferenceContext* c) {
  const int input_count = c->num_inputs();
  for (int in = 0; in < input_count; ++in) {
    ShapeHandle ignored;
    TF_RETURN_IF_ERROR(c->WithRank(c->input(in), 0, &ignored));
  }

  const int output_count = c->num_outputs();
  for (int out = 0; out < output_count; ++out) {
    c->set_output(out, c->Scalar());
  }
  return Status::OK();
}

// Shape function for ops whose inputs are all 2-element vectors and whose
// outputs (if any) are all scalars.
//
// Returns an error if any input cannot be constrained to rank 1 with a first
// dimension of 2; otherwise sets every output shape to a scalar.
Status TwoElementVectorAndScalarOutputs(InferenceContext* c) {
  const int input_count = c->num_inputs();
  for (int in = 0; in < input_count; ++in) {
    ShapeHandle vector_shape;
    TF_RETURN_IF_ERROR(c->WithRank(c->input(in), 1, &vector_shape));
    DimensionHandle ignored_dim;
    TF_RETURN_IF_ERROR(c->WithValue(c->Dim(vector_shape, 0), 2, &ignored_dim));
  }

  const int output_count = c->num_outputs();
  for (int out = 0; out < output_count; ++out) {
    c->set_output(out, c->Scalar());
  }
  return Status::OK();
}

// Shape function that reports output 0 as a vector of exactly two elements.
// It performs no input validation and always returns OK.
Status TwoElementOutput(InferenceContext* c) {
  const ShapeHandle pair_shape = c->Vector(2);
  c->set_output(0, pair_shape);
  return Status::OK();
}

}  // namespace

// SaveV2: takes a `prefix`, two string inputs (`tensor_names` and
// `shape_and_slices`) and the list of `tensors` typed by `dtypes`. The op has
// no outputs, so its shape function only validates the inputs.
REGISTER_OP("SaveV2")
    .Input("prefix: string")
    .Input("tensor_names: string")
    .Input("shape_and_slices: string")
    .Input("tensors: dtypes")
    .Attr("dtypes: list(type)")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      // `prefix` must be a scalar.
      ShapeHandle prefix_shape;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &prefix_shape));

      // `tensor_names` and `shape_and_slices` must be vectors whose length
      // equals the number of inputs that follow the three leading ones.
      const int num_tensors = c->num_inputs() - 3;
      for (int input_idx = 1; input_idx < 3; ++input_idx) {
        ShapeHandle vector_shape;
        TF_RETURN_IF_ERROR(c->WithRank(c->input(input_idx), 1, &vector_shape));
        DimensionHandle length;
        TF_RETURN_IF_ERROR(
            c->WithValue(c->Dim(vector_shape, 0), num_tensors, &length));
      }
      // TODO(mrry): Attempt to parse the shape_and_slices values and use
      // them to constrain the shape of the remaining inputs.
      return Status::OK();
    });

// RestoreV2: takes a scalar `prefix` and two string vectors (`tensor_names`
// and `shape_and_slices`) and produces one output per entry of `dtypes`.
//
// Shape inference validates the input ranks, requires the two vectors to have
// compatible shapes, and, when `shape_and_slices` is available as a constant,
// derives every output shape from the corresponding slice specification. An
// empty specification yields an unknown shape for that output. When the
// constant is not available, all outputs get unknown shapes.
REGISTER_OP("RestoreV2")
    .Input("prefix: string")
    .Input("tensor_names: string")
    .Input("shape_and_slices: string")
    .Output("tensors: dtypes")
    .Attr("dtypes: list(type)")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle shape0, shape1, shape2;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &shape0));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &shape1));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &shape2));
      // tensor_names and shape_and_slices must agree on their length.
      TF_RETURN_IF_ERROR(c->Merge(shape1, shape2, &shape0));

      // Attempt to infer output shapes from its shape_and_slice input.
      const Tensor* shape_and_slices_tensor = c->input_tensor(2);
      if (shape_and_slices_tensor) {
        // Typed access to a tensor of a different dtype is a fatal check
        // failure rather than a recoverable error, so reject a mistyped
        // constant (possible in a malformed graph) with a Status first.
        if (shape_and_slices_tensor->dtype() != DT_STRING) {
          return errors::InvalidArgument(
              "Expected an input tensor of type string.");
        }

        const auto& shape_and_slices_flat =
            shape_and_slices_tensor->flat<string>();
        if (shape_and_slices_flat.size() != c->num_outputs()) {
          return errors::InvalidArgument(
              "The number of shape_and_slice doesn't match tensor outputs.");
        }
        for (int i = 0; i < shape_and_slices_flat.size(); ++i) {
          const string& shape_and_slice = shape_and_slices_flat(i);
          if (shape_and_slice.empty()) {
            // No slice spec for this tensor: nothing can be inferred.
            c->set_output(i, c->UnknownShape());
            continue;
          }
          TensorShape parsed_full_shape;
          TensorSlice parsed_slice;
          TensorShape parsed_slice_shape;
          TF_RETURN_IF_ERROR(checkpoint::ParseShapeAndSlice(
              shape_and_slice, &parsed_full_shape, &parsed_slice,
              &parsed_slice_shape));
          // The restored tensor has the shape of the slice, not of the full
          // tensor it was cut from.
          ShapeHandle shape_handle;
          TF_RETURN_IF_ERROR(
              c->MakeShapeFromTensorShape(parsed_slice_shape, &shape_handle));
          c->set_output(i, shape_handle);
        }
        return Status::OK();
      } else {
        return UnknownShape(c);
      }
    });

// MergeV2Checkpoints: takes `checkpoint_prefixes` and `destination_prefix`
// string inputs plus a `delete_old_dirs` bool attr (default true). It has no
// outputs; the shape function only validates input ranks.
REGISTER_OP("MergeV2Checkpoints")
    .Input("checkpoint_prefixes: string")
    .Input("destination_prefix: string")
    .Attr("delete_old_dirs: bool = true")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      // `checkpoint_prefixes` must be a vector.
      ShapeHandle prefixes_shape;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &prefixes_shape));
      // `destination_prefix` must be a scalar.
      ShapeHandle destination_shape;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &destination_shape));
      return Status::OK();
    });

// Save: takes a `filename`, a `tensor_names` string input and the list of
// `data` tensors typed by `T`. It has no outputs; the shape function only
// validates the inputs.
REGISTER_OP("Save")
    .Input("filename: string")
    .Input("tensor_names: string")
    .Input("data: T")
    .Attr("T: list(type)")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      // `filename` must be a scalar.
      ShapeHandle filename_shape;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &filename_shape));

      // `tensor_names` must be a vector whose length equals the number of
      // inputs that follow the two leading ones.
      ShapeHandle names_shape;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &names_shape));
      const int num_tensors = c->num_inputs() - 2;
      DimensionHandle names_length;
      TF_RETURN_IF_ERROR(
          c->WithValue(c->Dim(names_shape, 0), num_tensors, &names_length));

      return Status::OK();
    });

// SaveSlices: like Save, with an additional `shapes_and_slices` string input
// between `tensor_names` and the `data` list. It has no outputs; the shape
// function only validates the inputs.
REGISTER_OP("SaveSlices")
    .Input("filename: string")
    .Input("tensor_names: string")
    .Input("shapes_and_slices: string")
    .Input("data: T")
    .Attr("T: list(type)")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      // `filename` must be a scalar.
      ShapeHandle filename_shape;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &filename_shape));

      // `tensor_names` and `shapes_and_slices` must be vectors whose length
      // equals the number of inputs that follow the three leading ones.
      const int num_tensors = c->num_inputs() - 3;
      for (int input_idx = 1; input_idx < 3; ++input_idx) {
        ShapeHandle vector_shape;
        TF_RETURN_IF_ERROR(c->WithRank(c->input(input_idx), 1, &vector_shape));
        DimensionHandle length;
        TF_RETURN_IF_ERROR(
            c->WithValue(c->Dim(vector_shape, 0), num_tensors, &length));
      }
      // TODO(mrry): Attempt to parse the shapes_and_slices values and use
      // them to constrain the shape of the remaining inputs.
      return Status::OK();
    });

// Restore: takes `file_pattern` and `tensor_name` string inputs and produces
// a single `tensor` output of type `dt`. Both inputs must be scalars; nothing
// is known about the output shape at graph-construction time.
REGISTER_OP("Restore")
    .Input("file_pattern: string")
    .Input("tensor_name: string")
    .Output("tensor: dt")
    .Attr("dt: type")
    .Attr("preferred_shard: int = -1")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      // Check `file_pattern` first, then `tensor_name`.
      for (int idx = 0; idx < 2; ++idx) {
        ShapeHandle scalar_shape;
        TF_RETURN_IF_ERROR(c->WithRank(c->input(idx), 0, &scalar_shape));
      }
      const ShapeHandle restored_shape = c->UnknownShape();
      c->set_output(0, restored_shape);
      return Status::OK();
    });

// RestoreSlice: takes three string inputs (`file_pattern`, `tensor_name`,
// `shape_and_slice`) and produces a single `tensor` output of type `dt`.
//
// Shape inference requires all three inputs to be scalars. When
// `shape_and_slice` is available as a constant and is non-empty, the output
// shape is the slice shape parsed from it; otherwise the output shape is
// unknown.
REGISTER_OP("RestoreSlice")
    .Input("file_pattern: string")
    .Input("tensor_name: string")
    .Input("shape_and_slice: string")
    .Output("tensor: dt")
    .Attr("dt: type")
    .Attr("preferred_shard: int = -1")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));

      // Attempt to infer output shapes from its shape_and_slice input.
      const Tensor* shape_and_slices_tensor = c->input_tensor(2);
      if (shape_and_slices_tensor) {
        // Guard the typed element access below. The rank check above also
        // passes when the statically inferred rank is unknown, so it does not
        // by itself guarantee that the constant holds a string element.
        // Report a Status instead of reading a mistyped or empty tensor.
        if (shape_and_slices_tensor->dtype() != DT_STRING) {
          return errors::InvalidArgument(
              "Expected an input tensor of type string.");
        }
        if (shape_and_slices_tensor->NumElements() < 1) {
          return errors::InvalidArgument(
              "shape_and_slice must not be an empty tensor.");
        }

        const auto& shape_and_slice =
            shape_and_slices_tensor->flat<string>()(0);
        if (shape_and_slice.empty()) {
          // No slice spec: nothing can be inferred.
          c->set_output(0, c->UnknownShape());
        } else {
          TensorShape parsed_full_shape;
          TensorSlice parsed_slice;
          TensorShape parsed_slice_shape;
          TF_RETURN_IF_ERROR(checkpoint::ParseShapeAndSlice(
              shape_and_slice, &parsed_full_shape, &parsed_slice,
              &parsed_slice_shape));
          // The restored tensor has the shape of the slice, not of the full
          // tensor it was cut from.
          ShapeHandle shape_handle;
          TF_RETURN_IF_ERROR(
              c->MakeShapeFromTensorShape(parsed_slice_shape, &shape_handle));
          c->set_output(0, shape_handle);
        }
      } else {
        c->set_output(0, c->UnknownShape());
      }
      return Status::OK();
    });

// ShardedFilename: inputs are `basename` (string), `shard` (int32) and
// `num_shards` (int32); the single output is `filename` (string).
// ScalarInputsAndOutputs requires every input to be a scalar and reports the
// output as a scalar.
REGISTER_OP("ShardedFilename")
    .Input("basename: string")
    .Input("shard: int32")
    .Input("num_shards: int32")
    .Output("filename: string")
    .SetShapeFn(ScalarInputsAndOutputs);

// ShardedFilespec: inputs are `basename` (string) and `num_shards` (int32);
// the single output is `filename` (string). ScalarInputsAndOutputs requires
// every input to be a scalar and reports the output as a scalar.
REGISTER_OP("ShardedFilespec")
    .Input("basename: string")
    .Input("num_shards: int32")
    .Output("filename: string")
    .SetShapeFn(ScalarInputsAndOutputs);

// Reader source ops ----------------------------------------------------------

// WholeFileReader: stateful op with no inputs that outputs a Ref(string)
// `reader_handle`. `container` and `shared_name` default to the empty string.
// TwoElementOutput reports the handle as a 2-element vector.
REGISTER_OP("WholeFileReader")
    .Output("reader_handle: Ref(string)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput);

// WholeFileReaderV2: same attrs as WholeFileReader, but the `reader_handle`
// output is a `resource` instead of a Ref(string). Its shape function is
// shape_inference::ScalarShape (defined outside this file).
REGISTER_OP("WholeFileReaderV2")
    .Output("reader_handle: resource")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

// TextLineReader: stateful op with no inputs that outputs a Ref(string)
// `reader_handle`, reported by TwoElementOutput as a 2-element vector.
// `skip_header_lines` defaults to 0. Marked deprecated (version argument 26)
// with a message pointing to TextLineReaderV2.
REGISTER_OP("TextLineReader")
    .Output("reader_handle: Ref(string)")
    .Attr("skip_header_lines: int = 0")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput)
    .Deprecated(26, "Use TextLineReaderV2");

// TextLineReaderV2: same attrs as TextLineReader, but the `reader_handle`
// output is a `resource`. Its shape function is shape_inference::ScalarShape
// (defined outside this file).
REGISTER_OP("TextLineReaderV2")
    .Output("reader_handle: resource")
    .Attr("skip_header_lines: int = 0")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

// FixedLengthRecordReader: stateful op with no inputs that outputs a
// Ref(string) `reader_handle`, reported by TwoElementOutput as a 2-element
// vector. `record_bytes` has no default and must be supplied; `header_bytes`,
// `footer_bytes` and `hop_bytes` default to 0. Marked deprecated (version
// argument 26) with a message pointing to FixedLengthRecordReaderV2.
REGISTER_OP("FixedLengthRecordReader")
    .Output("reader_handle: Ref(string)")
    .Attr("header_bytes: int = 0")
    .Attr("record_bytes: int")
    .Attr("footer_bytes: int = 0")
    .Attr("hop_bytes: int = 0")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput)
    .Deprecated(26, "Use FixedLengthRecordReaderV2");

// FixedLengthRecordReaderV2: `resource`-handle variant of
// FixedLengthRecordReader. It has the same attrs plus an `encoding` string
// attr that defaults to the empty string. Its shape function is
// shape_inference::ScalarShape (defined outside this file).
REGISTER_OP("FixedLengthRecordReaderV2")
    .Output("reader_handle: resource")
    .Attr("header_bytes: int = 0")
    .Attr("record_bytes: int")
    .Attr("footer_bytes: int = 0")
    .Attr("hop_bytes: int = 0")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .Attr("encoding: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

// TFRecordReader: stateful op with no inputs that outputs a Ref(string)
// `reader_handle`, reported by TwoElementOutput as a 2-element vector.
// `compression_type` defaults to the empty string. Marked deprecated (version
// argument 26) with a message pointing to TFRecordReaderV2.
REGISTER_OP("TFRecordReader")
    .Output("reader_handle: Ref(string)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .Attr("compression_type: string = ''")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput)
    .Deprecated(26, "Use TFRecordReaderV2");

// TFRecordReaderV2: same attrs as TFRecordReader, but the `reader_handle`
// output is a `resource`. Its shape function is shape_inference::ScalarShape
// (defined outside this file).
REGISTER_OP("TFRecordReaderV2")
    .Output("reader_handle: resource")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .Attr("compression_type: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

// LMDBReader: stateful op with no inputs that outputs a Ref(string)
// `reader_handle`, reported by TwoElementOutput as a 2-element vector.
// No `resource`-handle variant of this op is registered in this part of the
// file.
REGISTER_OP("LMDBReader")
    .Output("reader_handle: Ref(string)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput);

// IdentityReader: stateful op with no inputs that outputs a Ref(string)
// `reader_handle`, reported by TwoElementOutput as a 2-element vector.
// Marked deprecated (version argument 26) with a message pointing to
// IdentityReaderV2.
REGISTER_OP("IdentityReader")
    .Output("reader_handle: Ref(string)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput)
    .Deprecated(26, "Use IdentityReaderV2");

// IdentityReaderV2: same attrs as IdentityReader, but the `reader_handle`
// output is a `resource`. Its shape function is shape_inference::ScalarShape
// (defined outside this file).
REGISTER_OP("IdentityReaderV2")
    .Output("reader_handle: resource")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

// Ops that operate on Readers ------------------------------------------------

// ReaderRead: takes Ref(string) `reader_handle` and `queue_handle` inputs and
// produces string outputs `key` and `value`.
// TwoElementVectorAndScalarOutputs requires both handles to be 2-element
// vectors and reports both outputs as scalars.
REGISTER_OP("ReaderRead")
    .Input("reader_handle: Ref(string)")
    .Input("queue_handle: Ref(string)")
    .Output("key: string")
    .Output("value: string")
    .SetShapeFn(TwoElementVectorAndScalarOutputs);

// ReaderReadV2: `resource`-handle variant of ReaderRead.
// ScalarInputsAndOutputs requires both handles to be scalars and reports the
// `key` and `value` outputs as scalars.
REGISTER_OP("ReaderReadV2")
    .Input("reader_handle: resource")
    .Input("queue_handle: resource")
    .Output("key: string")
    .Output("value: string")
    .SetShapeFn(ScalarInputsAndOutputs);

// ReaderReadUpTo: takes Ref(string) `reader_handle` and `queue_handle`
// inputs plus an int64 `num_records`, and produces string outputs `keys` and
// `values`.
//
// Shape inference requires both handles to be 2-element vectors and
// `num_records` to be a scalar. Both outputs are reported as vectors of
// unknown length that share one shape handle.
REGISTER_OP("ReaderReadUpTo")
    .Input("reader_handle: Ref(string)")
    .Input("queue_handle: Ref(string)")
    .Input("num_records: int64")
    .Output("keys: string")
    .Output("values: string")
    .SetShapeFn([](InferenceContext* c) {
      // Apply the same Ref(string) handle check that ReaderRead gets from
      // TwoElementVectorAndScalarOutputs: rank 1 with exactly two elements.
      for (int i = 0; i < 2; ++i) {
        ShapeHandle handle_shape;
        DimensionHandle unused_dim;
        TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &handle_shape));
        TF_RETURN_IF_ERROR(
            c->WithValue(c->Dim(handle_shape, 0), 2, &unused_dim));
      }

      // `num_records` must be a scalar.
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));

      ShapeHandle out = c->Vector(InferenceContext::kUnknownDim);
      c->set_output(0, out);
      c->set_output(1, out);
      return Status::OK();
    });

// ReaderReadUpToV2: `resource`-handle variant of ReaderReadUpTo. All three
// inputs must be scalars; both outputs are reported as vectors of unknown
// length that share one shape handle.
REGISTER_OP("ReaderReadUpToV2")
    .Input("reader_handle: resource")
    .Input("queue_handle: resource")
    .Input("num_records: int64")
    .Output("keys: string")
    .Output("values: string")
    .SetShapeFn([](InferenceContext* c) {
      // reader_handle, queue_handle and num_records, in that order.
      for (int idx = 0; idx < 3; ++idx) {
        ShapeHandle scalar_shape;
        TF_RETURN_IF_ERROR(c->WithRank(c->input(idx), 0, &scalar_shape));
      }

      const ShapeHandle records_shape =
          c->Vector(InferenceContext::kUnknownDim);
      for (int out_idx = 0; out_idx < 2; ++out_idx) {
        c->set_output(out_idx, records_shape);
      }
      return Status::OK();
    });

// ReaderNumRecordsProduced: takes a Ref(string) `reader_handle` and produces
// an int64 `records_produced`. TwoElementVectorAndScalarOutputs requires the
// handle to be a 2-element vector and reports the output as a scalar.
REGISTER_OP("ReaderNumRecordsProduced")
    .Input("reader_handle: Ref(string)")
    .Output("records_produced: int64")
    .SetShapeFn(TwoElementVectorAndScalarOutputs);

// ReaderNumRecordsProducedV2: `resource`-handle variant of
// ReaderNumRecordsProduced. ScalarInputsAndOutputs requires the handle to be
// a scalar and reports the output as a scalar.
REGISTER_OP("ReaderNumRecordsProducedV2")
    .Input("reader_handle: resource")
    .Output("records_produced: int64")
    .SetShapeFn(ScalarInputsAndOutputs);

// ReaderNumWorkUnitsCompleted: takes a Ref(string) `reader_handle` and
// produces an int64 `units_completed`. TwoElementVectorAndScalarOutputs
// requires the handle to be a 2-element vector and reports the output as a
// scalar.
REGISTER_OP("ReaderNumWorkUnitsCompleted")
    .Input("reader_handle: Ref(string)")
    .Output("units_completed: int64")
    .SetShapeFn(TwoElementVectorAndScalarOutputs);

// ReaderNumWorkUnitsCompletedV2: `resource`-handle variant of
// ReaderNumWorkUnitsCompleted. ScalarInputsAndOutputs requires the handle to
// be a scalar and reports the output as a scalar.
REGISTER_OP("ReaderNumWorkUnitsCompletedV2")
    .Input("reader_handle: resource")
    .Output("units_completed: int64")
    .SetShapeFn(ScalarInputsAndOutputs);

// ReaderSerializeState: takes a Ref(string) `reader_handle` and produces a
// string `state`. TwoElementVectorAndScalarOutputs requires the handle to be
// a 2-element vector and reports the output as a scalar.
REGISTER_OP("ReaderSerializeState")
    .Input("reader_handle: Ref(string)")
    .Output("state: string")
    .SetShapeFn(TwoElementVectorAndScalarOutputs);

// ReaderSerializeStateV2: `resource`-handle variant of ReaderSerializeState.
// ScalarInputsAndOutputs requires the handle to be a scalar and reports the
// output as a scalar.
REGISTER_OP("ReaderSerializeStateV2")
    .Input("reader_handle: resource")
    .Output("state: string")
    .SetShapeFn(ScalarInputsAndOutputs);

// ReaderRestoreState: takes a Ref(string) `reader_handle` and a string
// `state`. It has no outputs, so the shape function only validates the
// inputs: the handle must be a 2-element vector and `state` a scalar.
REGISTER_OP("ReaderRestoreState")
    .Input("reader_handle: Ref(string)")
    .Input("state: string")
    .SetShapeFn([](InferenceContext* c) {
      // Read the dimension from the rank-validated handle rather than from
      // the raw input, as TwoElementVectorAndScalarOutputs does.
      ShapeHandle handle_shape;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle_shape));
      DimensionHandle unused_dim;
      TF_RETURN_IF_ERROR(
          c->WithValue(c->Dim(handle_shape, 0), 2, &unused_dim));

      // `state` must be a scalar.
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      return Status::OK();
    });

// ReaderRestoreStateV2: `resource`-handle variant of ReaderRestoreState. It
// has no outputs; both inputs must be scalars.
REGISTER_OP("ReaderRestoreStateV2")
    .Input("reader_handle: resource")
    .Input("state: string")
    .SetShapeFn([](InferenceContext* c) {
      // Check `reader_handle` first, then `state`.
      for (int idx = 0; idx < 2; ++idx) {
        ShapeHandle scalar_shape;
        TF_RETURN_IF_ERROR(c->WithRank(c->input(idx), 0, &scalar_shape));
      }
      return Status::OK();
    });

// ReaderReset: takes a Ref(string) `reader_handle` and has no outputs.
// TwoElementVectorAndScalarOutputs therefore only checks that the handle is a
// 2-element vector.
REGISTER_OP("ReaderReset")
    .Input("reader_handle: Ref(string)")
    .SetShapeFn(TwoElementVectorAndScalarOutputs);

// ReaderResetV2: `resource`-handle variant of ReaderReset with no outputs.
// ScalarInputsAndOutputs therefore only checks that the handle is a scalar.
REGISTER_OP("ReaderResetV2")
    .Input("reader_handle: resource")
    .SetShapeFn(ScalarInputsAndOutputs);

// Other input Ops ----------------------------------------------------------

// ReadFile: takes a string `filename` and produces a string `contents`.
// ScalarInputsAndOutputs requires the input to be a scalar and reports the
// output as a scalar.
REGISTER_OP("ReadFile")
    .Input("filename: string")
    .Output("contents: string")
    .SetShapeFn(ScalarInputsAndOutputs);

// WriteFile: takes string inputs `filename` and `contents` and has no
// outputs. Both inputs must be scalars.
REGISTER_OP("WriteFile")
    .Input("filename: string")
    .Input("contents: string")
    .SetShapeFn([](InferenceContext* c) {
      // Check `filename` first, then `contents`.
      for (int idx = 0; idx < 2; ++idx) {
        ShapeHandle scalar_shape;
        TF_RETURN_IF_ERROR(c->WithRank(c->input(idx), 0, &scalar_shape));
      }
      return Status::OK();
    });

// MatchingFiles: takes a string `pattern` of rank at most 1 (a scalar or a
// vector) and produces `filenames`, reported as a vector of unknown length.
REGISTER_OP("MatchingFiles")
    .Input("pattern: string")
    .Output("filenames: string")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle pattern_shape;
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &pattern_shape));

      const ShapeHandle filenames_shape =
          c->Vector(InferenceContext::kUnknownDim);
      c->set_output(0, filenames_shape);
      return Status::OK();
    });

}  // namespace tensorflow
